{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)"
      ],
      "metadata": {
        "id": "YfAj9wQV3DCF"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/https://github.com/JohnSnowLabs/nlu/tree/master/examples/colab/component_examples/classifiers/Roberta_Zero_Shot_Classifier.ipynb)"
      ],
      "metadata": {
        "id": "gLOVuUuL3Ysb"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Zero Shot Text Classification\n",
        "\n",
        "State-of-the-art NLP models for text classification without annotated data\n",
        "\n",
        "Natural language processing is a very exciting field right now. In recent years, the community has begun to figure out some pretty effective methods of learning from the enormous amounts of unlabeled data available on the internet. The success of transfer learning from unsupervised models has allowed us to surpass virtually all existing benchmarks on downstream supervised learning tasks. As we continue to develop new model architectures and unsupervised learning objectives, \"state of the art\" continues to be a rapidly moving target for many tasks where large amounts of labeled data are available.\n",
        "\n",
        "### Zero Shot learning\n",
        "\n",
        "Zero-shot Learning (ZSL) is one of the most recent advancements in Machine Learning aimed to train Deep Neural Network models to have higher generalisability on unseen data. One of the most prominent methods of training such models is to use text prompts that explain the task to be solved, along with all possible outputs.\n",
        "\n",
        "The primary aim of using ZSL over supervised learning is to address the following limitations of training traditional supervised learning models:\n",
        "\n",
        "1. Training supervised NLP models require substantial amount of training data.\n",
        "2. Even with recent trend of fine-tuning large language models, the supervised approach of training or fine-tuning a model is basically to learn a very specific data distribution, which results in low performance when applied to diverse and unseen data.\n",
        "3. The classical annotate-train-test cycle is highly demanding in terms of temporal and human resources."
      ],
      "metadata": {
        "id": "tz13nqEF3Rh8"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Roberta Zero Shot Classifier\n",
        "\n",
        "This model is intended to be used for zero-shot text classification, especially in English. It is fine-tuned on NLI by using Roberta Base model.\n",
        "\n",
        "RoBertaForZeroShotClassificationusing a ModelForSequenceClassification trained on NLI (natural language inference) tasks. Equivalent of RoBertaForZeroShotClassification models, but these models don’t require a hardcoded number of potential classes, they can be chosen at runtime. It usually means it’s slower but it is much more flexible.\n",
        "\n",
        "We used TFRobertaForSequenceClassification to train this model and used RoBertaForZeroShotClassification annotator in Spark NLP 🚀 for prediction at scale!"
      ],
      "metadata": {
        "id": "q253sILp3UGa"
      }
    },
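    {
      "cell_type": "markdown",
      "source": [
        "To make the NLI idea above concrete, here is a minimal sketch in plain Python. It is not the Spark NLP implementation. It assumes one common scheme: each candidate label is turned into a hypothesis such as \"This example is about mobile.\", an NLI model scores how strongly the input text entails each hypothesis, and a softmax over those scores gives one confidence per label. The hypothesis wording, the function names, and the scores are all made up for illustration; no model is called here.\n",
        "\n",
        "Because the labels only enter through the hypotheses, the label set can change between calls without retraining."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import math\n",
        "\n",
        "def softmax(scores):\n",
        "    # Turn raw scores into probabilities that sum to 1.\n",
        "    peak = max(scores)  # subtract the max for numerical stability\n",
        "    exps = [math.exp(s - peak) for s in scores]\n",
        "    total = sum(exps)\n",
        "    return [e / total for e in exps]\n",
        "\n",
        "def zero_shot_classify(entailment_scores):\n",
        "    # entailment_scores maps each candidate label to the entailment score an\n",
        "    # NLI model would assign to the hypothesis built from that label.\n",
        "    labels = list(entailment_scores)\n",
        "    probs = softmax([entailment_scores[label] for label in labels])\n",
        "    best = max(range(len(labels)), key=lambda i: probs[i])\n",
        "    return labels[best], probs[best]\n",
        "\n",
        "# Hypothetical entailment scores for one input sentence (not model output).\n",
        "made_up_scores = {'mobile': 2.1, 'travel': -0.4, 'movie': -1.3, 'weather': -0.9}\n",
        "label, confidence = zero_shot_classify(made_up_scores)\n",
        "print(label, round(confidence, 3))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },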
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 693
        },
        "id": "6pScZG6m20Ek",
        "outputId": "90646ed1-8807-43cf-d7c6-9ead5a554357"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting nlu\n",
            "  Downloading nlu-4.2.0-py3-none-any.whl (646 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m646.3/646.3 kB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting spark-nlp>=4.2.0 (from nlu)\n",
            "  Downloading spark_nlp-5.0.1-py2.py3-none-any.whl (499 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m499.0/499.0 kB\u001b[0m \u001b[31m30.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from nlu) (1.22.4)\n",
            "Requirement already satisfied: pyarrow>=0.16.0 in /usr/local/lib/python3.10/dist-packages (from nlu) (9.0.0)\n",
            "Requirement already satisfied: pandas>=1.3.5 in /usr/local/lib/python3.10/dist-packages (from nlu) (1.5.3)\n",
            "Collecting dataclasses (from nlu)\n",
            "  Downloading dataclasses-0.6-py3-none-any.whl (14 kB)\n",
            "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->nlu) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->nlu) (2022.7.1)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas>=1.3.5->nlu) (1.16.0)\n",
            "Installing collected packages: spark-nlp, dataclasses, nlu\n",
            "Successfully installed nlu-4.2.0 dataclasses-0.6 spark-nlp-5.0.1\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "dataclasses"
                ]
              }
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting pyspark==3.0.1\n",
            "  Downloading pyspark-3.0.1.tar.gz (204.2 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m204.2/204.2 MB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Collecting py4j==0.10.9 (from pyspark==3.0.1)\n",
            "  Downloading py4j-0.10.9-py2.py3-none-any.whl (198 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m198.6/198.6 kB\u001b[0m \u001b[31m16.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hBuilding wheels for collected packages: pyspark\n",
            "  Building wheel for pyspark (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pyspark: filename=pyspark-3.0.1-py2.py3-none-any.whl size=204612226 sha256=b624e62124c5ec3da3cf943c5262441c47f1f0111a904a8512e80a5d9a732da9\n",
            "  Stored in directory: /root/.cache/pip/wheels/19/b0/c8/6cb894117070e130fc44352c2a13f15b6c27e440d04a84fb48\n",
            "Successfully built pyspark\n",
            "Installing collected packages: py4j, pyspark\n",
            "  Attempting uninstall: py4j\n",
            "    Found existing installation: py4j 0.10.9.7\n",
            "    Uninstalling py4j-0.10.9.7:\n",
            "      Successfully uninstalled py4j-0.10.9.7\n",
            "Successfully installed py4j-0.10.9 pyspark-3.0.1\n"
          ]
        }
      ],
      "source": [
        "!pip install nlu\n",
        "!pip install pyspark==3.0.1"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import nlu\n",
        "import pandas as pd"
      ],
      "metadata": {
        "id": "tlCKungX3I5s"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "text = [\"I have a problem with my iphone that needs to be resolved asap.\",\n",
        "        \"Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app.\",\n",
        "        \"I have a phone and I love it.\",\n",
        "        \"I really want to visit Germany and I am planning to go there next year.\",\n",
        "        \"I loved this movie when I was a child.\",\n",
        "        \"I always hated this movie and it's plot.\",\n",
        "        \"The product is not worth it's hype hence suggest it not buying it.\",\n",
        "        \"This sounds like a good plan, let's stick to it and sort this sprint.\",\n",
        "        \"Sound is terrible if u want good music too get a bose\",\n",
        "        \"Works great, upgraded from echo dot to full size echo, couldn’t be happier!\"\n",
        "        ]"
      ],
      "metadata": {
        "id": "5ldx8EtU3KVM"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "roberta_zero_shot = nlu.load(\"en.roberta.zero_shot_classifier\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "X7bHv8MX3Lrr",
        "outputId": "58f8878a-041d-4af2-a67c-e89e97af7456"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Warning::Spark Session already created, some configs may not take.\n",
            "roberta_base_zero_shot_classifier_nli download started this may take some time.\n",
            "Approximate size to download 444.8 MB\n",
            "[OK!]\n"
          ]
        }
      ]
    },
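    {
      "cell_type": "markdown",
      "source": [
        "The `nlu.load` call above does not set any candidate labels. If you want to choose them yourself, one option is to build the pipeline directly in Spark NLP. The sketch below assumes that the installed Spark NLP version provides `RoBertaForZeroShotClassification` with a `setCandidateLabels` setter, and that `roberta_base_zero_shot_classifier_nli` (the name printed in the download message above) is the pretrained model name. The label list is a hypothetical example. The cell reuses the `text` list defined earlier and the Spark session that NLU has already started."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from pyspark.ml import Pipeline\n",
        "from pyspark.sql import SparkSession\n",
        "from sparknlp.base import DocumentAssembler\n",
        "from sparknlp.annotator import Tokenizer, RoBertaForZeroShotClassification\n",
        "\n",
        "# Reuse the running Spark session instead of creating a new one.\n",
        "spark = SparkSession.builder.getOrCreate()\n",
        "\n",
        "# Hypothetical label set, chosen only for illustration.\n",
        "candidate_labels = ['mobile', 'travel', 'movie', 'music', 'urgent']\n",
        "\n",
        "document_assembler = DocumentAssembler().setInputCol('text').setOutputCol('document')\n",
        "tokenizer = Tokenizer().setInputCols(['document']).setOutputCol('token')\n",
        "zero_shot = (\n",
        "    RoBertaForZeroShotClassification.pretrained('roberta_base_zero_shot_classifier_nli', 'en')\n",
        "    .setInputCols(['document', 'token'])\n",
        "    .setOutputCol('class')\n",
        "    .setCandidateLabels(candidate_labels)  # labels are supplied at runtime\n",
        ")\n",
        "\n",
        "pipeline = Pipeline(stages=[document_assembler, tokenizer, zero_shot])\n",
        "data = spark.createDataFrame([[t] for t in text]).toDF('text')\n",
        "\n",
        "# Show each input next to the label predicted from the custom label set.\n",
        "pipeline.fit(data).transform(data).select('text', 'class.result').show(truncate=60)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },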
    {
      "cell_type": "code",
      "source": [
        "results_roberta = roberta_zero_shot.predict(text, output_level = 'sentence')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1F3yvI-D3N01",
        "outputId": "d87f0471-ff8d-43c5-a43f-db17ba0272ea"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "sentence_detector_dl download started this may take some time.\n",
            "Approximate size to download 354.6 KB\n",
            "[OK!]\n",
            "Warning::Spark Session already created, some configs may not take.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "results_roberta"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 363
        },
        "id": "pUGz-04x3PQi",
        "outputId": "c7b77b02-601d-4ed4-c92b-e949eda1b15b"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "  classified_sequence classified_sequence_confidence  \\\n",
              "0              mobile                       0.178295   \n",
              "1              travel                        0.15382   \n",
              "2              mobile                       0.141618   \n",
              "3             weather                       0.190693   \n",
              "4             weather                       0.141874   \n",
              "5              mobile                       0.135632   \n",
              "6              mobile                       0.177482   \n",
              "7              mobile                         0.1595   \n",
              "8              mobile                       0.146786   \n",
              "9              mobile                       0.315334   \n",
              "\n",
              "                                            sentence  \n",
              "0  I have a problem with my iphone that needs to ...  \n",
              "1  Last week I upgraded my iOS version and ever s...  \n",
              "2                      I have a phone and I love it.  \n",
              "3  I really want to visit Germany and I am planni...  \n",
              "4             I loved this movie when I was a child.  \n",
              "5           I always hated this movie and it's plot.  \n",
              "6  The product is not worth it's hype hence sugge...  \n",
              "7  This sounds like a good plan, let's stick to i...  \n",
              "8  Sound is terrible if u want good music too get...  \n",
              "9  Works great, upgraded from echo dot to full si...  "
            ],
            "text/html": [
              "\n",
              "\n",
              "  <div id=\"df-1a37ac58-05b3-48d6-a5b2-9cc596b22eb6\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>classified_sequence</th>\n",
              "      <th>classified_sequence_confidence</th>\n",
              "      <th>sentence</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>mobile</td>\n",
              "      <td>0.178295</td>\n",
              "      <td>I have a problem with my iphone that needs to ...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>travel</td>\n",
              "      <td>0.15382</td>\n",
              "      <td>Last week I upgraded my iOS version and ever s...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>mobile</td>\n",
              "      <td>0.141618</td>\n",
              "      <td>I have a phone and I love it.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>weather</td>\n",
              "      <td>0.190693</td>\n",
              "      <td>I really want to visit Germany and I am planni...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>weather</td>\n",
              "      <td>0.141874</td>\n",
              "      <td>I loved this movie when I was a child.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>mobile</td>\n",
              "      <td>0.135632</td>\n",
              "      <td>I always hated this movie and it's plot.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>mobile</td>\n",
              "      <td>0.177482</td>\n",
              "      <td>The product is not worth it's hype hence sugge...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>mobile</td>\n",
              "      <td>0.1595</td>\n",
              "      <td>This sounds like a good plan, let's stick to i...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>mobile</td>\n",
              "      <td>0.146786</td>\n",
              "      <td>Sound is terrible if u want good music too get...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9</th>\n",
              "      <td>mobile</td>\n",
              "      <td>0.315334</td>\n",
              "      <td>Works great, upgraded from echo dot to full si...</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-1a37ac58-05b3-48d6-a5b2-9cc596b22eb6')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "\n",
              "\n",
              "\n",
              "    <div id=\"df-5903229a-cbad-447a-a093-273ff66dda46\">\n",
              "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-5903229a-cbad-447a-a093-273ff66dda46')\"\n",
              "              title=\"Suggest charts.\"\n",
              "              style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "      </button>\n",
              "    </div>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "    background-color: #E8F0FE;\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: #1967D2;\n",
              "    height: 32px;\n",
              "    padding: 0 0 0 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: #E2EBFA;\n",
              "    box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: #174EA6;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "    background-color: #3B4455;\n",
              "    fill: #D2E3FC;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart:hover {\n",
              "    background-color: #434B5C;\n",
              "    box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "    filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "    fill: #FFFFFF;\n",
              "  }\n",
              "</style>\n",
              "\n",
              "    <script>\n",
              "      async function quickchart(key) {\n",
              "        const containerElement = document.querySelector('#' + key);\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      }\n",
              "    </script>\n",
              "\n",
              "      <script>\n",
              "\n",
              "function displayQuickchartButton(domScope) {\n",
              "  let quickchartButtonEl =\n",
              "    domScope.querySelector('#df-5903229a-cbad-447a-a093-273ff66dda46 button.colab-df-quickchart');\n",
              "  quickchartButtonEl.style.display =\n",
              "    google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "}\n",
              "\n",
              "        displayQuickchartButton(document);\n",
              "      </script>\n",
              "      <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-1a37ac58-05b3-48d6-a5b2-9cc596b22eb6 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-1a37ac58-05b3-48d6-a5b2-9cc596b22eb6');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ]
    }
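    ,
    {
      "cell_type": "markdown",
      "source": [
        "Nine of the ten confidences in the table above are below 0.2, and some labels look off (the two movie sentences are tagged `weather` and `mobile`). A simple safeguard is to flag low-confidence rows for review. The sketch below assumes the column names shown in the output. The 0.2 threshold and the `confidence` and `needs_review` columns are hypothetical choices, not part of NLU."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "\n",
        "MIN_CONFIDENCE = 0.2  # hypothetical cut-off, tune it on your own data\n",
        "\n",
        "# The confidence column may hold strings, so coerce it to floats first.\n",
        "confidence = pd.to_numeric(\n",
        "    results_roberta['classified_sequence_confidence'], errors='coerce'\n",
        ")\n",
        "\n",
        "# Add a numeric confidence and a flag for rows that fall below the cut-off.\n",
        "results_checked = results_roberta.assign(\n",
        "    confidence=confidence,\n",
        "    needs_review=confidence < MIN_CONFIDENCE,\n",
        ")\n",
        "results_checked[['sentence', 'classified_sequence', 'confidence', 'needs_review']]"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    }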
  ]
}